import os
from collections.abc import Generator

import pytest

from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.entities.message_entities import (
    AssistantPromptMessage,
    PromptMessageTool,
    SystemPromptMessage,
    TextPromptMessageContent,
    UserPromptMessage,
)
from core.model_runtime.entities.model_entities import AIModelEntity
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.chatglm.llm.llm import ChatGLMLargeLanguageModel
from tests.integration_tests.model_runtime.__mock.openai import setup_openai_mock


def test_predefined_models():
    """The ChatGLM provider exposes at least one predefined ``AIModelEntity`` schema."""
    schemas = ChatGLMLargeLanguageModel().predefined_models()

    assert len(schemas) >= 1
    first_schema = schemas[0]
    assert isinstance(first_schema, AIModelEntity)

@pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True)
def test_validate_credentials_for_chat_model(setup_openai_mock):
    """Credential validation rejects a bogus endpoint and accepts the configured one.

    ``setup_openai_mock`` is parametrized indirectly with ``['chat']``; the
    fixture is defined in the shared ``__mock.openai`` test helpers
    (presumably it patches the OpenAI client used by the provider — confirm there).
    """
    llm = ChatGLMLargeLanguageModel()

    # A nonsense api_base must surface as CredentialsValidateFailedError.
    bad_credentials = {'api_base': 'invalid_key'}
    with pytest.raises(CredentialsValidateFailedError):
        llm.validate_credentials(model='chatglm2-6b', credentials=bad_credentials)

    # The endpoint taken from CHATGLM_API_BASE is expected to validate cleanly.
    good_credentials = {'api_base': os.environ.get('CHATGLM_API_BASE')}
    llm.validate_credentials(model='chatglm2-6b', credentials=good_credentials)

@pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True)
def test_invoke_model(setup_openai_mock):
    """A blocking (``stream=False``) chat call returns a populated ``LLMResult``."""
    llm = ChatGLMLargeLanguageModel()
    conversation = [
        SystemPromptMessage(content='You are a helpful AI assistant.'),
        UserPromptMessage(content='Hello World!'),
    ]
    sampling = {'temperature': 0.7, 'top_p': 1.0}

    result = llm.invoke(
        model='chatglm2-6b',
        credentials={'api_base': os.environ.get('CHATGLM_API_BASE')},
        prompt_messages=conversation,
        model_parameters=sampling,
        stop=['you'],
        user="abc-123",
        stream=False,
    )

    # The reply must be an LLMResult with some text and non-zero token usage.
    assert isinstance(result, LLMResult)
    assert len(result.message.content) > 0
    assert result.usage.total_tokens > 0

@pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True)
def test_invoke_stream_model(setup_openai_mock):
    """A streaming chat call yields well-formed ``LLMResultChunk`` objects.

    Every chunk must wrap an ``AssistantPromptMessage`` delta. Chunks whose
    ``finish_reason`` is still ``None`` must carry some text. The stream must
    produce at least one chunk.
    """
    model = ChatGLMLargeLanguageModel()

    response = model.invoke(
        model='chatglm2-6b',
        credentials={
            'api_base': os.environ.get('CHATGLM_API_BASE')
        },
        prompt_messages=[
            SystemPromptMessage(
                content='You are a helpful AI assistant.',
            ),
            UserPromptMessage(
                content='Hello World!'
            )
        ],
        model_parameters={
            'temperature': 0.7,
            'top_p': 1.0,
        },
        stop=['you'],
        stream=True,
        user="abc-123"
    )

    assert isinstance(response, Generator)

    chunk_count = 0
    for chunk in response:
        chunk_count += 1
        assert isinstance(chunk, LLMResultChunk)
        assert isinstance(chunk.delta, LLMResultChunkDelta)
        assert isinstance(chunk.delta.message, AssistantPromptMessage)
        # Only chunks without a finish_reason are required to carry text;
        # the chunk that reports finish_reason may be empty.
        if chunk.delta.finish_reason is None:
            assert len(chunk.delta.message.content) > 0

    # Without this check, an empty stream would let the loop above pass vacuously.
    assert chunk_count > 0

@pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True)
def test_invoke_stream_model_with_functions(setup_openai_mock):
    """A streaming call with a tool definition eventually yields a tool call.

    The stream is consumed until the first chunk that carries ``tool_calls``.
    That chunk must name the ``get_current_weather`` tool. Chunks seen before
    it are checked for shape. Chunks without a finish_reason and without a
    tool call must carry text.
    """
    model = ChatGLMLargeLanguageModel()

    response = model.invoke(
        model='chatglm3-6b',
        credentials={
            'api_base': os.environ.get('CHATGLM_API_BASE')
        },
        prompt_messages=[
            # System prompt (Chinese): "You are a weather bot. You don't know what
            # today's weather is; you need to call a function to get weather info."
            SystemPromptMessage(
                content='你是一个天气机器人，你不知道今天的天气怎么样，你需要通过调用一个函数来获取天气信息。'
            ),
            # User prompt (Chinese): "What's the weather like in Boston?"
            UserPromptMessage(
                content='波士顿天气如何？'
            )
        ],
        model_parameters={
            'temperature': 0,
            'top_p': 1.0,
        },
        stop=['you'],
        user='abc-123',
        stream=True,
        tools=[
            PromptMessageTool(
                name='get_current_weather',
                description='Get the current weather in a given location',
                parameters={
                    "type": "object",
                    "properties": {
                        "location": {
                            "type": "string",
                            "description": "The city and state e.g. San Francisco, CA"
                        },
                        "unit": {
                            "type": "string",
                            "enum": ["celsius", "fahrenheit"]
                        }
                    },
                    "required": [
                        "location"
                    ]
                }
            )
        ]
    )

    assert isinstance(response, Generator)

    call: LLMResultChunk | None = None

    for chunk in response:
        assert isinstance(chunk, LLMResultChunk)
        assert isinstance(chunk.delta, LLMResultChunkDelta)
        assert isinstance(chunk.delta.message, AssistantPromptMessage)

        # Look for the tool call BEFORE asserting on text: a chunk that carries
        # a tool call is not required to also carry text content.
        if chunk.delta.message.tool_calls:
            call = chunk
            break

        if chunk.delta.finish_reason is None:
            assert len(chunk.delta.message.content) > 0

    assert call is not None
    assert call.delta.message.tool_calls[0].function.name == 'get_current_weather'

@pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True)
def test_invoke_model_with_functions(setup_openai_mock):
    """A blocking call with a tool definition returns a ``get_current_weather`` tool call.

    The result must be an ``LLMResult`` with non-zero token usage. Its message
    must contain at least one tool call, and the first one must name
    ``get_current_weather``.
    """
    model = ChatGLMLargeLanguageModel()

    response = model.invoke(
        model='chatglm3-6b',
        credentials={
            'api_base': os.environ.get('CHATGLM_API_BASE')
        },
        prompt_messages=[
            UserPromptMessage(
                content='What is the weather like in San Francisco?'
            )
        ],
        model_parameters={
            'temperature': 0.7,
            'top_p': 1.0,
        },
        stop=['you'],
        user='abc-123',
        stream=False,
        tools=[
            PromptMessageTool(
                name='get_current_weather',
                description='Get the current weather in a given location',
                parameters={
                    "type": "object",
                    "properties": {
                        "location": {
                            "type": "string",
                            "description": "The city and state e.g. San Francisco, CA"
                        },
                        "unit": {
                            "type": "string",
                            "enum": [
                                "c",
                                "f"
                            ]
                        }
                    },
                    "required": [
                        "location"
                    ]
                }
            )
        ]
    )

    assert isinstance(response, LLMResult)
    assert response.usage.total_tokens > 0
    # A reply that requests a tool call is not required to carry text. This
    # matches the streaming tool-call test, so text content is not asserted.
    # Guard the index below so a missing tool call fails with a clear
    # assertion instead of an IndexError/TypeError.
    assert response.message.tool_calls, 'expected the model to request a tool call'
    assert response.message.tool_calls[0].function.name == 'get_current_weather'


def test_get_num_tokens():
    """Token counting returns pinned integer counts for a fixed two-message prompt.

    The count is checked once with a single weather-tool schema (expected 77)
    and once without any tools (expected 21), both for ``chatglm2-6b``.
    """
    llm = ChatGLMLargeLanguageModel()

    # Fresh credential dicts and message lists are built for each call, so
    # neither call can observe objects handed to the other.
    def _credentials():
        return {'api_base': os.environ.get('CHATGLM_API_BASE')}

    def _messages():
        return [
            SystemPromptMessage(content='You are a helpful AI assistant.'),
            UserPromptMessage(content='Hello World!'),
        ]

    weather_tool = PromptMessageTool(
        name='get_current_weather',
        description='Get the current weather in a given location',
        parameters={
            "type": "object",
            "properties": {
                "location": {
                    "type": "string",
                    "description": "The city and state e.g. San Francisco, CA",
                },
                "unit": {
                    "type": "string",
                    "enum": ["c", "f"],
                },
            },
            "required": ["location"],
        },
    )

    tokens_with_tool = llm.get_num_tokens(
        model='chatglm2-6b',
        credentials=_credentials(),
        prompt_messages=_messages(),
        tools=[weather_tool],
    )
    assert isinstance(tokens_with_tool, int)
    assert tokens_with_tool == 77

    tokens_without_tool = llm.get_num_tokens(
        model='chatglm2-6b',
        credentials=_credentials(),
        prompt_messages=_messages(),
    )
    assert isinstance(tokens_without_tool, int)
    assert tokens_without_tool == 21